Add dtype parameter to kspaceFirstOrder() (#695) #716
Conversation
Exposes precision control on the modern unified API to match the legacy `SimulationOptions.data_cast` and MATLAB k-Wave's `DataCast`:

- `data_cast='off'` -> `np.float64` (default; matches legacy)
- `data_cast='double'` -> `np.float64` (alias for `'off'`, MATLAB compat)
- `data_cast='single'` -> `np.float32` (~half memory, faster, lower accuracy)

Python backend: plumbs through `Simulation`, which now stores `self._dtype` and uses it for all state arrays (p, u, rho_split, sensor_data buffers, PML arrays, alpha_coeff/BonA/p0 expansions, source signal buffers). Default behavior unchanged (float64 everywhere).

C++ backend: `data_cast` has no effect; the binary uses fixed internal precision regardless of HDF5 input dtype. Setting anything other than `'off'`/`'double'` with `backend='cpp'` emits a `UserWarning` explaining this and pointing users at `backend='python'` for precision control.

Tests: 8 new in `tests/test_data_cast.py` covering: output dtype matches request, default behavior unchanged, invalid value raises, single vs double numerical agreement within float32 tolerance, and the C++ warn/silent paths. Wider suite (62 tests across native_solver, ivp_homogeneous, issue_664) still passes.

Closes #695.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
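Illustratively, the alias mapping and the "~half memory" claim above can be checked with plain numpy (the names below are for illustration only, not the PR's code):

```python
import numpy as np

# Illustrative mapping of the three accepted data_cast values (not the
# PR's actual code).
DATA_CAST = {"off": np.float64, "double": np.float64, "single": np.float32}

# Single precision uses half the memory of double for the same grid:
a32 = np.zeros((128, 128), dtype=np.float32)
a64 = np.zeros((128, 128), dtype=np.float64)
print(a32.nbytes, a64.nbytes)  # 65536 131072
```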
Codecov Report: ❌ Patch coverage is …

Additional details and impacted files:

```
@@           Coverage Diff            @@
##           master     #716     +/-  ##
=========================================
+ Coverage   75.04%   75.34%   +0.29%
=========================================
  Files          57       57
  Lines        8128     8164      +36
  Branches     1584     1593       +9
=========================================
+ Hits         6100     6151      +51
+ Misses       1405     1392      -13
+ Partials      623      621       -2
```
Flags with carried forward coverage won't be shown.
Make the precision parameter Pythonic instead of MATLAB-stringly-typed.
The numpy ecosystem's convention is to accept dtype-like inputs broadly
(numpy types, strings, dtype objects), and the modern API should
follow that idiom rather than the legacy SimulationOptions.data_cast
naming.
Accepted forms (resolved via _resolve_dtype, which uses np.dtype()):
None / np.float64 / "float64" / "double" / float / "off" / np.dtype("f8")
-> np.float64 (default)
np.float32 / "float32" / "single" / np.dtype("f4")
-> np.float32
The MATLAB aliases ("off", "single", "double") are kept as compat
shortcuts so users porting from the legacy API or MATLAB k-Wave have
zero friction. Anything resolving to a non-float32/float64 type
(np.float16, np.complex64, etc.) raises ValueError -- the solver
isn't validated for those.
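A minimal sketch of this resolution logic, assuming a helper of roughly this shape (the PR's actual _resolve_dtype may differ in details):

```python
import numpy as np

# Hypothetical sketch of the resolution described above.
_MATLAB_ALIASES = {"off": "float64", "double": "float64", "single": "float32"}

def resolve_dtype(value):
    if value is None:
        return np.dtype(np.float64)
    if isinstance(value, str):
        value = _MATLAB_ALIASES.get(value, value)
    dt = np.dtype(value)  # accepts numpy types, strings, dtype objects
    if dt not in (np.dtype(np.float32), np.dtype(np.float64)):
        raise ValueError(f"dtype must resolve to float32 or float64, got {dt}")
    return dt
```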
C++ backend warns when dtype is not np.float64 (binary uses fixed
internal precision regardless).
Tests: 22 (was 8) parametrized over every input form. Wider suite
(62 tests) still passes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@greptile-apps re-review
Greptile P2 (test for p_final dtype) caught a real bug: with
dtype='single', p_final came back as float64 even though sensor_data
buffers (p, p_max, p_min, p_rms) were correctly float32.
Root cause: two sources of float64 leaking into the hot loop:
1. xp.fft.fftfreq returns float64; k_list, kappa, op_grad/div_list,
_k_mag inherited it. _diff's FFT round-trip (float64 op * complex64
field) upcasts to complex128, .real => float64. Result: self.p and
self.u rebound to float64 mid-step() despite being allocated as
float32. p_final = self.p[interior].copy() picked up float64.
sensor_data buffers stayed float32 because writes are in-place into
the pre-allocated buffer (silent narrowing on assignment).
2. get_pml returns float64 unconditionally; the per-step pml multiply
was a second upcast path independent of (1).
Both cast sites now apply .astype(self._dtype) at construction time,
keeping the entire compute pipeline in the user's requested precision.
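The upcast path in (1) can be reproduced with plain numpy:

```python
import numpy as np

# Demonstrates the float64 leak described above: fftfreq output is float64,
# and any float64 operator upcasts the complex64 field.
k = np.fft.fftfreq(8)                      # always float64
field = np.zeros(8, dtype=np.complex64)

print((k * field).dtype)        # complex128: the float64 operator upcasts
print((k * field).real.dtype)   # float64 leaks back into the state array

k32 = k.astype(np.float32)      # the fix: cast at construction time
print((k32 * field).dtype)      # complex64 stays complex64
```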
Test updated: float32 / float64 input parametrizations now request
('p', 'p_final', 'p_max', 'p_min', 'p_rms') and assert every field's
dtype matches. Verified: float32 inputs => all five fields float32;
float64 => all five float64.
Bonus: helpful error for torch / jax / tensorflow dtype objects via
duck-typed __module__ check (no framework imports needed); cupy works
for free since cp.float32 is np.float32.
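The duck-typed check could look roughly like this (a hypothetical sketch; the PR's actual helper and message may differ):

```python
import numpy as np

# Hypothetical sketch: inspect the object's type's __module__ instead of
# importing torch / jax / tensorflow.
def framework_hint(obj):
    root = (getattr(type(obj), "__module__", "") or "").split(".")[0]
    if root in ("torch", "jax", "jaxlib", "tensorflow"):
        return f"got a {root} dtype; pass a numpy dtype such as np.float32"
    return None

print(framework_hint(np.float32))  # None: plain numpy input needs no hint
```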
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…odern-api # Conflicts: # kwave/solvers/kspace_solver.py
Greptile spotted a third dtype-drift path: ``sum(rho_split)`` in ``_nl_factor`` and the equation-of-state line starts with Python ``int 0``. Under numpy < 2 (pre-NEP 50), ``int + float32 -> float64``, so:

    nl_factor = (2 * sum(rho_split) + rho0) / rho0

is float64 even when rho_split is float32. The product ``rho0 * div_u_i * nl_factor`` in mass conservation then upcasts the rho_split arrays to float64 on the very first step. This specifically affects any simulation that enables BonA.

Fix: an ``_array_sum`` helper that starts the accumulator from ``arrays[0]`` so the dtype is preserved. Used in both call sites (the ``_nl_factor`` lambda and the equation-of-state ``rho_total``).

Test added: ``test_python_backend_dtype_preserved_with_nonlinearity`` exercises the BonA path with ``dtype=np.float32`` and asserts p / p_final / p_max all remain float32.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@greptile-apps re-review
End-to-end verification under numpy 1.26.4 (with my dtype tests requesting
('p', 'p_final', 'p_max', 'p_min', 'p_rms')) showed self.p still upcast to
float64 mid-step despite all the prior precision fixes.
Root cause: numpy < 2's `np.fft.fftn` always returns complex128 regardless
of input precision -- a known difference resolved in numpy 2 (NEP 50 era).
The k-space ops (op_grad/div_list, unstagger_ops) being complex64 isn't
enough; multiplying complex64 by complex128 upcasts to complex128, and
.real -> float64 propagates back into self.p.
Fix: introduce self._complex_dtype (complex64 / complex128 matching
self._dtype), and cast every fftn / ifftn .real result back to the
intended precision. Three call sites: step() momentum loop, sensor-data
unstagger, _diff helper. ``copy=False`` makes the cast a no-op when the
dtype already matches (numpy 2, cupy).
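The cast-back pattern can be reproduced with plain numpy (a sketch; the variable names are illustrative):

```python
import numpy as np

p = np.random.default_rng(0).random((16, 16)).astype(np.float32)
op = np.ones((16, 16), dtype=np.complex64)

# On numpy < 2, fftn returns complex128 even for low-precision input, so the
# .real result must be cast back; copy=False makes the cast free when the
# dtype already matches.
out = np.fft.ifftn(op * np.fft.fftn(p)).real.astype(np.float32, copy=False)
print(out.dtype)  # float32 on both numpy 1.x and 2.x
```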
Greptile flagged dt_over_rho0, kappa, and sum() as P1 dtype-promotion
paths. Verified empirically on numpy 1.26.4 that none of those upcast in
practice (Python float / int / complex are weak types in numpy 1.x as
well) -- but the test failures Greptile would have seen had the same
visible signature, so the diagnosis pointed at the right region.
Tests: 86 (24 dtype + 62 wider suite) pass on numpy 1.26.4 AND numpy 2.2.6.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Greptile was right that the flagged diff was unsafe on numpy<2, but the actual culprit was different; fixed in 22ec78d. So: same conclusion (not safe on numpy<2 before this commit), different mechanism. Re-triggering review now.

@greptile-apps re-review
```python
# Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
```
unstagger_ops also built with 1j * literal → always complex128 on numpy < 2
The same Python-literal promotion issue as op_grad_list/op_div_list applies here: -1j * self.k_list[ax] is complex128 on numpy < 2, regardless of k_list dtype. Cast to self._complex_dtype after construction to keep the unstaggering FFT in the correct precision.
```diff
  # Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
- self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
+ self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2).astype(self._complex_dtype) for ax in range(self.ndim)]
```
Verified empirically on numpy 1.26.4: -1j * k_f32 * spacing / 2 evaluates to complex64 and np.exp(...) preserves complex64 — Python complex is a weak type just like Python float/int, so the multiplication chain stays in complex64 when k_list is float32. So unstagger_ops is already complex64 without an explicit cast.
Additionally, the prior commit (22ec78d) already added .astype(self._dtype, copy=False) to the staggered ifftn .real line in step(), so the recorded output stays float32 even if unstagger_ops were ever to change precision.
Added a regression test (d545952) that records ('p', 'ux_staggered', 'uy_staggered') with dtype=np.float32 and asserts each is float32 — locks the behavior in for both numpy 1.26.4 and 2.2.6.
@greptile-apps re-review
The empirical result doesn't match what NEP 50's own migration table documents. The NEP 50 spec explicitly shows the old (numpy < 2) vs new (numpy ≥ 2 / NEP 50) behaviour:
| Expression | numpy < 2 (old) | numpy ≥ 2 (NEP 50) |
|---|---|---|
| `(3j + array(3, complex64)).dtype` | `complex128` | `complex64` |
| `(float32(1) + 1j).dtype` | `complex128` | `complex64` |
Source: https://numpy.org/neps/nep-0050-scalar-promotion.html
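Those two rows can be evaluated directly; per the migration table, both print complex128 on numpy < 2 and complex64 on numpy ≥ 2:

```python
import numpy as np

# The two expressions from the NEP 50 migration table; the result dtype
# depends on the numpy major version, as the table shows.
d1 = (3j + np.array(3, dtype=np.complex64)).dtype
d2 = (np.float32(1) + 1j).dtype
print(d1, d2)
```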
Python complex literals are not weak types in numpy < 2 — they behave like np.complex128 in the type promotion system. The "weak scalar" treatment is exactly what NEP 50 introduced. So on numpy 1.26.4, -1j * k_f32 * spacing / 2 → complex128, and np.exp(...) of a complex128 array → complex128.
It's possible the empirical test ran against a numpy 2.x environment, or that k_list[ax] was already float64 at the point of measurement (defeating the test). The regression test added in d545952 checks recorded sensor_data["p"] and "ux_staggered" dtypes, but those go through the .astype(self._dtype, copy=False) guard added in step() — so they'll pass regardless of what unstagger_ops itself is. The dtype of unstagger_ops isn't directly asserted.
The fix remains: cast the list comprehension result to self._complex_dtype:
```diff
  # Spectral shift: move velocity from staggered (mid-cell) to collocated (pressure) grid
- self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2) for ax in range(self.ndim)]
+ self.unstagger_ops = [xp.exp(-1j * self.k_list[ax] * self.spacing[ax] / 2).astype(self._complex_dtype) for ax in range(self.ndim)]
```
Greptile flagged unstagger_ops as a P1 (claimed 1j * k_f32 -> complex128
on numpy<2). Empirically that's not true under numpy 1.26.4 -- Python
complex is a weak type just like Python float/int, so unstagger_ops
stays complex64 when k_list is float32. The recorded ux_staggered output
also stays float32 because the prior commit already cast the unstagger
ifftn .real result to self._dtype.
Adding a test that records ('p', 'ux_staggered', 'uy_staggered') and
asserts each is float32 under dtype=np.float32 -- guards against any
future regression in the staggered-velocity path on either numpy major
version. 25 tests pass on numpy 1.26.4 and numpy 2.2.6.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@greptile-apps re-review

@copilot resolve the merge conflicts in this pull request
Co-authored-by: waltsims <8669206+waltsims@users.noreply.github.com>
Resolved. I merged
@greptile re-review

Closes #338
```diff
  self._nonlinearity = lambda rho: self.BonA * rho**2 / (2 * self.rho0)
- self._nl_factor = lambda rho_split: (2 * sum(rho_split) + self.rho0) / self.rho0
+ self._nl_factor = lambda rho_split: (2 * _array_sum(rho_split) + self.rho0) / self.rho0
```
int 2 literal re-introduces dtype promotion on numpy < 2
_array_sum was added to prevent sum(rho_split) starting from Python int 0, but 2 * _array_sum(rho_split) still multiplies by a Python int — which numpy < 2 (pre-NEP 50) treats as np.int64 (a strong type). np.result_type(np.int64, np.float32) → np.float64, so nl_factor is float64 when BonA is enabled on numpy 1.x, and propagates through rho_split[i] into self.p → p_final. The same int 2 divisor in _nonlinearity (2 * self.rho0) causes the same promotion for the equation-of-state nonlinear term. On numpy >= 2 (NEP 50 weak scalars) the tests pass, but on numpy 1.26 they will silently produce float64 p_final instead of float32.
```diff
- self._nonlinearity = lambda rho: self.BonA * rho**2 / (2 * self.rho0)
- self._nl_factor = lambda rho_split: (2 * _array_sum(rho_split) + self.rho0) / self.rho0
+ _two = self.xp.asarray(2, dtype=self._dtype)
+ self._nonlinearity = lambda rho: self.BonA * rho**2 / (_two * self.rho0)
+ self._nl_factor = lambda rho_split: (_two * _array_sum(rho_split) + self.rho0) / self.rho0
```
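The promotion in question can be checked with explicit dtypes, which promote the same way on every numpy version:

```python
import numpy as np

# np.int64 is a strong type: pairing it with float32 promotes to float64.
print(np.result_type(np.int64, np.float32))  # float64

# A scalar carrying the target dtype avoids the promotion entirely.
two = np.float32(2)
x = np.zeros(3, dtype=np.float32)
print((two * x).dtype)  # float32
```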
**Greptile's dt_over_rho0 / rho_split claim is empirically false on numpy 1.26.4**

Greptile's review claims that `dt_over_rho0` and the `rho_split` update silently promote to float64 on numpy < 2. Direct verification with the freshly-installed environment (numpy 1.26.4):

```python
import numpy as np

dt = 1.23e-9  # Python float (what self.dt is)
rho = np.full((4, 4), 1000.0, dtype=np.float32)
print((dt / rho).dtype)  # float32 (line 663)

rho0 = np.full((4, 4), 1000.0, dtype=np.float32)
div_u_i = np.full((4, 4), 0.5, dtype=np.float32)
nl_factor = 1.0
print((dt * rho0 * div_u_i * nl_factor).dtype)  # float32 (line 716, linear)

rho_split = [np.full((4, 4), 0.1, dtype=np.float32) for _ in range(2)]
nl_factor_nl = (2 * rho_split[0] + 2 * rho_split[1] + rho0) / rho0
print((dt * rho0 * div_u_i * nl_factor_nl).dtype)  # float32 (line 716, nonlinear)
```

The full NEP 50 changed numpy scalar dtype rules (…
@greptile re-review
Closes #695.

**What**

Exposes precision control on the modern unified API. Pythonic / numpy-idiomatic naming and accepted input forms:

| Input | Resolves to |
|---|---|
| `None` (default) | `np.float64` |
| `np.float64` / `"float64"` / `"double"` / `float` / `np.dtype("f8")` | `np.float64` |
| `np.float32` / `"float32"` / `"single"` / `np.dtype("f4")` | `np.float32` |
| `"off"` (legacy MATLAB alias) | `np.float64` |
| anything else (`np.float16`, `np.complex64`, `"quad"`, …) | `ValueError` |

The MATLAB aliases (`"off"`, `"single"`, `"double"`) are kept as compat shortcuts for users porting from the legacy `SimulationOptions.data_cast` or MATLAB k-Wave's `DataCast`. Everything else uses `np.dtype()` for normalization, matching the broader numpy/scipy/torch convention.

**Why `dtype` instead of `data_cast`**

`data_cast` is a MATLAB term. The numpy ecosystem (numpy, pandas, jax, torch) uses `dtype` and accepts dtype-like inputs broadly. The modern unified API is a fresh design, so it should follow the Python idiom rather than the MATLAB one. The MATLAB-style strings still work, so MATLAB users lose nothing.

**How**

Python backend plumbs `dtype` through `Simulation`, which now stores `self._dtype` and uses it for every state-array allocation: `p`, `u`, `rho_split`, sensor-data buffers, PML arrays, source signal buffers, and the `_expand_to_grid` helper for `sound_speed` / `density` / `alpha_coeff` / `BonA` / `p0`. Default behavior unchanged (float64 everywhere).

C++ backend intentionally has no effect: the binary uses fixed internal precision regardless of HDF5 input dtype. Setting `dtype` to anything other than `np.float64` with `backend='cpp'` emits a `UserWarning` explaining this and pointing users at `backend='python'` for precision control.

**Test plan**

New file `tests/test_data_cast.py` (22 tests):

- `test_python_backend_float64_inputs` parametrized over `[None, np.float64, "float64", "double", float, "off", np.dtype("f8")]`: every form resolves to float64 output
- `test_python_backend_float32_inputs` parametrized over `[np.float32, "float32", "single", np.dtype("f4")]`: every form resolves to float32 output
- `test_default_dtype_is_float64`: calling without the kwarg gives float64 (back-compat)
- `test_invalid_dtype_raises` parametrized over `[np.float16, np.complex64, "float16", "complex64", "quad", 42, "garbage"]`: all raise `ValueError`
- `test_python_single_vs_double_numerical_agreement`: single and double runs agree to within `1e-4` relative error
- `test_cpp_backend_warns_on_non_float64_dtype`: `UserWarning` fires before binary runs
- `test_cpp_backend_silent_on_default_dtype`: no warning on default

Wider suite verified (62 tests): `test_native_solver`, `test_ivp_homogeneous_medium`, `test_issue_664_alpha_power_near_unity` all pass.

**Greptile Summary**

This PR adds a `dtype` parameter to `kspaceFirstOrder()` that lets callers control state-array precision for the Python backend (`np.float32` or `np.float64`), with MATLAB-style string aliases (`"single"`, `"double"`, `"off"`) kept for migration compatibility. The C++ backend correctly ignores the parameter and emits a `UserWarning` when a non-float64 value is passed. `_resolve_dtype` normalises every dtype-like input form and raises `ValueError` with framework-specific hints for torch/jax objects. Default behaviour (float64 everywhere) is unchanged. `Simulation` now stores `self._dtype` and `self._complex_dtype`, plumbed through field allocation, PML arrays, k-vectors, sensor buffers, and source operators; several FFT round-trips gained explicit `.astype()` casts. Remaining gaps:

- `kappa` / `source_kappa` not cast back to `self._dtype`
- `dt_over_rho0` still `float(self.dt) / float32_rho`
- `self.dt * self.rho0` on the `rho_split` update likewise uncast

On numpy>=2 all tests pass; on numpy 1.x these silently produce float64 arrays even when `dtype=np.float32`.

**Confidence Score: 3/5**
Safe to merge for numpy>=2 environments; float32 precision is not reliably enforced on numpy 1.x due to several uncast Python-scalar multiplications identified in prior review rounds.
The dtype-plumbing is thorough for array allocations, PML, k-vectors, and sensor buffers, and the FFT cast-back guards prevent most output-dtype drift. However, kappa/source_kappa construction (Python-float c_ref times float32 k_mag), dt_over_rho0 (Python float divided by float32), and the rho_split update (self.dt * self.rho0) still produce float64 intermediates on numpy 1.x, silently defeating the float32 promise on that platform.
`kwave/solvers/kspace_solver.py`: `_setup_kspace_operators` (`kappa`/`source_kappa`), the `dt_over_rho0` precomputation in `_setup_fields`, and the `rho_split` update in `step()` all have uncast Python-scalar multiplications that defeat float32 precision on numpy 1.x.
Important Files Changed
Comments Outside Diff (7)
1. `kwave/solvers/kspace_solver.py`, line 663 (link): `dt_over_rho0` computed as Python `float` / float32 → silently float64 on numpy < 2. `self.dt` is stored as `float(self.kgrid.dt)` (a Python float, equivalent to `float64`). Dividing by a `float32` array gives a `float64` result under numpy < 2 (NEP 50 changed this in numpy 2.0). As a consequence, on every `step()`: `self.dt_over_rho0[i] * grad_p_i` (line 706) is float64 × float32 → float64, so `self.u[i]` is rebound to a float64 array after the very first step. The same Python-scalar promotion also affects line 716 (`self.dt * self.rho0 * div_u_i * nl_factor`), so `rho_split[i]` and, through `rho_total`, `self.p` also become float64. The sensor-data buffer (`sensor_data["p"]`) is pre-allocated float32 and silently narrows values on in-place assignment, so `result["p"]` tests pass, but `result["p_final"]` (line 773) is `self.p[interior].copy()`, with no narrowing, and will be float64 even when `dtype=np.float32` is requested, breaking the dtype contract on numpy < 2.

2. `kwave/solvers/kspace_solver.py`, lines 429-447 (link): complex128 on numpy < 2. `k_list` entries are now correctly cast to `self._dtype` (float32 when requested), but `self.c_ref` and `self.dt` are Python floats (float64), so `self.c_ref * k_mag * self.dt / 2` is float64 on numpy < 2, making `kappa` and `source_kappa` float64. The `1j *` Python complex literal then forces `op_grad_list` and `op_div_list` to complex128 rather than complex64. Cast `kappa`/`source_kappa` to `self._dtype` and the final operators to `self._complex_dtype` after construction.

3. `kwave/solvers/kspace_solver.py`, lines 661-663 (link): `self.dt` is stored as `float(self.kgrid.dt)`, a Python `float`, which numpy < 2.0 (pre-NEP 50) treats as `np.float64` in type promotion. Dividing a float64 scalar by a float32 array yields float64 on numpy 1.x, so `dt_over_rho0` is float64 even when `self._dtype is np.float32`. On the first `step()`, `self.dt_over_rho0[i] * grad_p_i` (float64 × float32 → float64) and the outer `pml_sg * (...)` product make `self.u[i]` float64. The same Python-scalar promotion on line 716 (`self.dt * self.rho0 * div_u_i`) then forces `self.rho_split[i]` to float64, which propagates through `_array_sum(rho_split)` into `self.p`. Because `result["p_final"]` is a direct `.copy()` of `self.p` (no pre-allocated float32 buffer to narrow into), it will be float64 on numpy < 2 even when `dtype=np.float32`, causing `test_python_backend_float32_inputs` to fail on numpy 1.x. Pre-cast `dt` to `self._dtype` at setup time and reuse it in `step()`.

4. `kwave/solvers/kspace_solver.py`, line 716 (link): `self.dt` is a Python `float` (float64), so `self.dt * self.rho0` (float64 × float32) → float64 on numpy < 2. `nl_factor = 1.0` (Python float) compounds this in the linear path. The result is that `rho_split[i]` ends up as float64, which propagates through `rho_total` into `self.p`, making `p_final` float64 regardless of `self._dtype`. Replace `self.dt` with `self._dt_typed` (the dtype-cast scalar computed during setup).

5. `kwave/solvers/kspace_solver.py`, line 663 (link): `dt_over_rho0` is computed by dividing a Python `float` (`self.dt = float(kgrid.dt)`) by a float32 array. On numpy < 2 (pre-NEP 50), Python scalars are strong `np.float64`, so `self.dt / rho` yields a float64 result for each element of the list. At line 706, `self.dt_over_rho0[i] * grad_p_i` (float64 × float32) then rebinds `self.u[i]` to float64; there is no `.astype()` guard on that assignment, unlike the `_diff()` return path. `u_final` (line 777) is then `self.u[i][interior].copy()`, which will be float64 even when `dtype=np.float32`.

6. `kwave/solvers/kspace_solver.py`, line 716 (link): `self.dt` is a Python `float` (float64). On numpy < 2, float64 × float32 is a strong-type promotion to float64, so `self.dt * self.rho0 * div_u_i * nl_factor` evaluates to float64 and rebinds `self.rho_split[i]` to a float64 array; there is no `.astype()` narrowing guard here. Float64 `rho_split` entries then flow through `_array_sum(self.rho_split)` → `self.p`, making `p_final` float64 even when `dtype=np.float32`.

7. `kwave/solvers/kspace_solver.py`, lines 429-431 (link): `self.c_ref` and `self.dt` are both Python `float` (= float64). On numpy < 2, float64 * float32_array is a strong promotion, so `self.c_ref * k_mag * self.dt / 2` produces a float64 array even though `k_mag` was cast to `self._dtype`. `kappa` and `source_kappa` are therefore float64, which in turn forces `op_grad_list`/`op_div_list` to complex128 (Python `1j` is also a strong complex128 on numpy < 2). Every call to `_diff()` then runs the FFT round-trip in float64 arithmetic even when `dtype=np.float32` was requested, defeating the purpose of the precision parameter.

Reviews (10): Last reviewed commit: "Merge branch 'master' into feature-data-..."